from typing import Any, Dict, Optional

from axelrod.action import Action

from axelrod.evolvable_player import (
    EvolvablePlayer,
    InsufficientParametersError,
    copy_lists,
    crossover_lists,
)

from axelrod.player import Player

# Short module-level aliases for the two members of Action used below.
C, D = Action.C, Action.D

def is_stochastic_matrix(m, ep=1e-8) -> bool:
    """Return True when every row of `m` is a probability distribution.

    Parameters
    ----------
    m : list of lists
        The candidate matrix, one inner list per row.
    ep : float
        Tolerance allowed between a row sum and 1.

    Returns
    -------
    bool
        False as soon as a row holds an entry outside [0, 1] or sums to
        something further than `ep` away from 1; True otherwise (including
        for a matrix with no rows).
    """
    for row in m:
        # Every entry has to be a valid probability on its own ...
        if any(entry < 0 or entry > 1 for entry in row):
            return False
        # ... and the row as a whole has to sum to one (within tolerance).
        if abs(1.0 - sum(row)) > ep:
            return False
    return True

def normalize_vector(vec):
    """Rescale `vec` so that its entries sum to one.

    Parameters
    ----------
    vec : sequence of numbers
        The values to normalize.  The input is not modified.

    Returns
    -------
    list
        A new list with each entry divided by the total.  When the total is
        zero (for example a row whose entries were all clamped to 0 by a
        mutation) there is nothing to rescale, so a uniform distribution of
        the same length is returned instead of dividing by zero.  An empty
        input yields an empty list.
    """
    s = sum(vec)
    if s == 0:
        # Avoid ZeroDivisionError: fall back to equal weights.
        size = len(vec)
        return [1 / size] * size if size else []
    return [v / s for v in vec]

def mutate_row(row, mutation_probability, rng):
    """
    Randomly perturb the entries of a row of probabilities, in place.

    Each entry is changed with probability `mutation_probability` (a value
    between 0 and 1).  A changed entry is shifted by an amount drawn
    uniformly from [-0.25, 0.25] and then clamped to the interval [0, 1].

    `rng` must provide `random(n)` (a sequence of n draws) and
    `uniform(low, high)`.  The same `row` object is returned after being
    modified; the result is not re-normalized.
    """
    draws = rng.random(len(row))
    for position, draw in zip(range(len(row)), draws):
        if not draw < mutation_probability:
            continue
        # A quarter of a uniform draw on [-1, 1] gives a shift in [-0.25, 0.25].
        shift = rng.uniform(-1, 1) / 4
        row[position] += shift
        if row[position] < 0:
            row[position] = 0
        elif row[position] > 1:
            row[position] = 1
    return row

class HMMPlayer(Player):
    """
    Abstract base class for Hidden Markov Model players.

    The first move is `initial_action`; every later move is obtained from the
    underlying `SimpleHMM`, which is given the opponent's previous action.

    Names

        - HMM Player: Original name by Marc Harper
    """

    name = "HMM Player"

    classifier = {
        "memory_depth": 1,
        "stochastic": True,
        "long_run_time": False,
        "inspects_source": False,
        "manipulates_source": False,
        "manipulates_state": False,
    }

    def __init__(
        self,
        transitions_C=None,
        transitions_D=None,
        emission_probabilities=None,
        initial_state=0,
        initial_action=C,
    ) -> None:
        """
        Build the player and its HMM.

        When `transitions_C` is missing or empty, all model parameters are
        replaced by a default single-state model, whatever else was passed.
        The HMM receives copies of the matrices and of the emission
        probabilities, so it does not share list objects with the caller.
        """
        Player.__init__(self)
        if not transitions_C:
            # Default model: one state, emission probability 0.5.
            transitions_C, transitions_D = [[1]], [[1]]
            emission_probabilities = [0.5]
            initial_state = 0
        self.initial_state = initial_state
        self.initial_action = initial_action
        model_arguments = (
            copy_lists(transitions_C),
            copy_lists(transitions_D),
            list(emission_probabilities),
            initial_state,
        )
        self.hmm = SimpleHMM(*model_arguments)
        assert self.hmm.is_well_formed()
        self.state = self.hmm.state
        # The class-level classifier says "stochastic"; refine it for the
        # parameters actually supplied.
        self.classifier["stochastic"] = self.is_stochastic()

    def is_stochastic(self) -> bool:
        """Determines if the player is stochastic."""
        # Collect every probability used by the model.  The player is
        # deterministic only when all of them are exactly 0 or 1; any other
        # value makes it stochastic.
        seen = set(self.hmm.emission_probabilities)
        for matrix in (self.hmm.transitions_C, self.hmm.transitions_D):
            seen.update(*matrix)
        return not seen.issubset({0, 1})

    def strategy(self, opponent: Player) -> Action:
        """Actual strategy definition that determines player's action."""
        if len(self.history) > 0:
            move = self.hmm.move(opponent.history[-1])
            # Mirroring the HMM state on the player is only for testing
            # purposes; the strategy itself does not need it.
            self.state = self.hmm.state
            return move
        # Nothing has been played yet: open with the configured action.
        return self.initial_action

    def set_seed(self, seed=None):
        """Seed the player and hand the resulting RNG to the HMM."""
        super().set_seed(seed=seed)
        # Share RNG with HMM.
        # The evolvable version of the class needs to manually share the rng
        # with the HMM after initialization, so a failed assignment here
        # (AttributeError) is tolerated.
        try:
            self.hmm._random = self._random
        except AttributeError:
            pass

class EvolvableHMMPlayer(HMMPlayer, EvolvablePlayer):
    """Evolvable version of HMMPlayer."""

    name = "EvolvableHMMPlayer"

    def __init__(
        self,
        transitions_C=None,
        transitions_D=None,
        emission_probabilities=None,
        initial_state=0,
        initial_action=C,
        num_states=None,
        mutation_probability=None,
        seed: Optional[int] = None,
    ) -> None:
        """
        Build the player either from explicit HMM parameters or, when these
        are incomplete, from random parameters for `num_states` states.

        Raises InsufficientParametersError (via `_normalize_parameters`) when
        neither a complete parameter set nor `num_states` is given.
        """
        EvolvablePlayer.__init__(self, seed=seed)
        (
            transitions_C,
            transitions_D,
            emission_probabilities,
            initial_state,
            initial_action,
            num_states,
            mutation_probability,
        ) = self._normalize_parameters(
            transitions_C,
            transitions_D,
            emission_probabilities,
            initial_state,
            initial_action,
            num_states,
            mutation_probability,
        )
        self.mutation_probability = mutation_probability
        HMMPlayer.__init__(
            self,
            transitions_C=transitions_C,
            transitions_D=transitions_D,
            emission_probabilities=emission_probabilities,
            initial_state=initial_state,
            initial_action=initial_action,
        )
        # The HMM has just been created, so share the player's RNG with it.
        self.hmm._random = self._random
        # Record the normalized values so that they are the ones stored as
        # this player's init kwargs.
        self.overwrite_init_kwargs(
            transitions_C=transitions_C,
            transitions_D=transitions_D,
            emission_probabilities=emission_probabilities,
            initial_state=initial_state,
            initial_action=initial_action,
            num_states=num_states,
            mutation_probability=mutation_probability,
        )

    def _normalize_parameters(
        self,
        transitions_C=None,
        transitions_D=None,
        emission_probabilities=None,
        initial_state=None,
        initial_action=None,
        num_states=None,
        mutation_probability=None,
    ):
        """
        Fill in missing parameters and normalize their types.

        If any of the HMM parameters is missing, all of them are replaced by
        random parameters for `num_states` states; without `num_states`
        InsufficientParametersError is raised.  Matrix rows and emission
        probabilities are converted to lists of floats (as new lists; the
        caller's objects are not modified), `num_states` is recomputed from
        the number of emission probabilities, and a missing
        `mutation_probability` defaults to 10 / num_states ** 2.

        Returns the seven parameters as a tuple, in signature order.
        """
        if not (
            (transitions_C and transitions_D and emission_probabilities)
            and (initial_state is not None)
            and (initial_action is not None)
        ):
            if not num_states:
                raise InsufficientParametersError(
                    "Insufficient Parameters to instantiate EvolvableHMMPlayer"
                )
            (
                transitions_C,
                transitions_D,
                emission_probabilities,
                initial_state,
                initial_action,
            ) = self.random_params(num_states)
        # Normalize types of various matrices.  New lists are built rather
        # than rewriting the rows of the matrices that were passed in.
        transitions_C = [list(map(float, row)) for row in transitions_C]
        transitions_D = [list(map(float, row)) for row in transitions_D]
        emission_probabilities = list(map(float, emission_probabilities))
        num_states = len(emission_probabilities)
        if mutation_probability is None:
            mutation_probability = 10 / (num_states**2)
        return (
            transitions_C,
            transitions_D,
            emission_probabilities,
            initial_state,
            initial_action,
            num_states,
            mutation_probability,
        )

    def random_params(self, num_states):
        """
        Draw random HMM parameters for `num_states` states using the
        player's RNG.

        Returns (transitions_C, transitions_D, emission_probabilities,
        initial_state, initial_action); the initial action is always C.
        """
        transitions_C = []
        transitions_D = []
        emission_probabilities = []
        for _ in range(num_states):
            transitions_C.append(self._random.random_vector(num_states))
            transitions_D.append(self._random.random_vector(num_states))
            emission_probabilities.append(self._random.random())
        # NOTE(review): assumes randint excludes the upper bound (numpy
        # semantics), so that the result is a valid state index -- confirm.
        initial_state = self._random.randint(0, num_states)
        initial_action = C
        return (
            transitions_C,
            transitions_D,
            emission_probabilities,
            initial_state,
            initial_action,
        )

    @property
    def num_states(self):
        """Number of states, taken from the HMM's emission probabilities."""
        return len(self.hmm.emission_probabilities)

    def mutate_rows(self, rows, mutation_probability):
        """
        Mutate and re-normalize every row of `rows`.

        `rows` is modified in place (each row is mutated, then replaced by
        its normalized version) and returned.
        """
        for i, row in enumerate(rows):
            row = mutate_row(row, mutation_probability, self._random)
            rows[i] = normalize_vector(row)
        return rows

    def mutate(self):
        """
        Return a new player whose parameters are a random mutation of this
        player's parameters.  This player's own HMM is left unchanged.
        """
        # mutate_rows and mutate_row modify their argument in place, so they
        # are given copies; otherwise creating a child would also alter the
        # HMM of the parent.
        transitions_C = self.mutate_rows(
            copy_lists(self.hmm.transitions_C), self.mutation_probability
        )
        transitions_D = self.mutate_rows(
            copy_lists(self.hmm.transitions_D), self.mutation_probability
        )
        emission_probabilities = mutate_row(
            list(self.hmm.emission_probabilities),
            self.mutation_probability,
            self._random,
        )
        # The initial action and state mutate less often than the matrix
        # entries: the probability is scaled down by 10 and by
        # 10 * num_states respectively.
        initial_action = self.initial_action
        if self._random.random() < self.mutation_probability / 10:
            initial_action = self.initial_action.flip()
        initial_state = self.initial_state
        if self._random.random() < self.mutation_probability / (
            10 * self.num_states
        ):
            initial_state = self._random.randint(0, self.num_states)
        return self.create_new(
            transitions_C=transitions_C,
            transitions_D=transitions_D,
            emission_probabilities=emission_probabilities,
            initial_state=initial_state,
            initial_action=initial_action,
        )

    def crossover(self, other):
        """
        Return a new player combining this player's HMM parameters with
        those of `other` via `crossover_lists`.

        Raises TypeError when `other` is not of exactly the same class.
        """
        if other.__class__ != self.__class__:
            raise TypeError(
                "Crossover must be between the same player classes."
            )
        transitions_C = crossover_lists(
            self.hmm.transitions_C, other.hmm.transitions_C, self._random
        )
        transitions_D = crossover_lists(
            self.hmm.transitions_D, other.hmm.transitions_D, self._random
        )
        emission_probabilities = crossover_lists(
            self.hmm.emission_probabilities,
            other.hmm.emission_probabilities,
            self._random,
        )
        return self.create_new(
            transitions_C=transitions_C,
            transitions_D=transitions_D,
            emission_probabilities=emission_probabilities,
        )

    def receive_vector(self, vector):
        """
        Read a serialized vector into the set of HMM parameters (less initial
        state).  Then assign those HMM parameters to this class instance.

        Assert that the vector has the right number of elements for a player
        with self.num_states states.

        Assume the first num_states^2 entries are the transitions_C matrix.  The
        next num_states^2 entries are the transitions_D matrix.  Then the next
        num_states entries are the emission_probabilities vector.  Finally the last
        entry is the initial_action (C when it rounds to 0, D otherwise).

        Each matrix row is normalized to sum to one.  The initial state is
        always reset to 0.
        """

        assert len(vector) == 2 * self.num_states**2 + self.num_states + 1

        def deserialize(vector):
            # Cut the flat slice into num_states rows and normalize each one.
            return [
                normalize_vector(
                    vector[self.num_states * i : self.num_states * (i + 1)]
                )
                for i in range(self.num_states)
            ]

        break_tc = self.num_states**2
        break_td = 2 * self.num_states**2
        break_ep = 2 * self.num_states**2 + self.num_states
        initial_state = 0
        # NOTE(review): the emission probabilities are normalized to sum to
        # one as well, like a matrix row -- confirm that this is intended.
        self.hmm = SimpleHMM(
            deserialize(vector[0:break_tc]),
            deserialize(vector[break_tc:break_td]),
            normalize_vector(vector[break_td:break_ep]),
            initial_state,
        )
        # A fresh HMM was created, so the player's RNG has to be shared with
        # it again, exactly as is done at the end of __init__.
        self.hmm._random = self._random
        self.initial_action = C if round(vector[-1]) == 0 else D
        self.initial_state = initial_state

    def create_vector_bounds(self):
        """Creates the bounds for the decision variables."""
        # One variable per entry of both matrices, one per emission
        # probability, plus one for the initial action; all lie in [0, 1].
        vec_len = 2 * self.num_states**2 + self.num_states + 1
        lb = [0.0] * vec_len
        ub = [1.0] * vec_len
        return lb, ub